Source code for hysop.problem

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import datetime
import sys

from hysop import __DEBUG__
from hysop.constants import HYSOP_DEFAULT_TASK_ID, Backend, MemoryOrdering
from hysop.core.checkpoints import CheckpointHandler
from hysop.core.graph.computational_graph import ComputationalGraph
from hysop.tools.contexts import Timer
from hysop.tools.decorators import debug, profile
from hysop.tools.parameters import MPIParams
from hysop.tools.string_utils import vprint, vprint_banner
from hysop.tools.htypes import check_instance, first_not_None, to_list, to_tuple


class Problem(ComputationalGraph):
    def __new__(
        cls, name=None, method=None, mpi_params=None, check_unique_clenv=True, **kwds
    ):
        return super().__new__(cls, **kwds)

    def __init__(
        self, name=None, method=None, mpi_params=None, check_unique_clenv=True, **kwds
    ):
        mpi_params = first_not_None(
            mpi_params, MPIParams()
        )  # enforce mpi params for problems
        super().__init__(name=name, method=method, mpi_params=mpi_params, **kwds)
        self._do_check_unique_clenv = check_unique_clenv
        self.search_intertasks_ops = None
        self.ops_tasks = []

    @debug
    def insert(self, *ops):
        for node in ops:
            if hasattr(node, "mpi_params") and node.mpi_params:
                self.ops_tasks.append(node.mpi_params.task_id)
            if hasattr(node, "impl_kwds") and "mpi_params" in node.impl_kwds:
                self.ops_tasks.append(node.impl_kwds["mpi_params"].task_id)
        given_ops_have_tasks = True
        pb_task_id = (
            HYSOP_DEFAULT_TASK_ID if self.mpi_params is None else self.mpi_params.task_id
        )
        if len(set(self.ops_tasks)) == 1 and self.ops_tasks[0] == pb_task_id:
            # Intertask handling is not needed: this is a single-task problem.
            given_ops_have_tasks = False
        self.search_intertasks_ops = given_ops_have_tasks
        self.push_nodes(*ops)
        return self

    @debug
    def build(
        self,
        args=None,
        allow_subbuffers=False,
        outputs_are_inputs=True,
        search_intertasks_ops=None,
    ):
        with Timer() as tm:
            msg = self.build_problem(
                args=args,
                allow_subbuffers=allow_subbuffers,
                outputs_are_inputs=outputs_are_inputs,
                search_intertasks_ops=search_intertasks_ops,
            )
            if msg:
                msg = f" Problem {msg} achieved, exiting ! "
                vprint_banner(msg, at_border=2)
                sys.exit(0)

        comm = self.mpi_params.comm
        if (self.domain is not None) and self.domain.has_tasks:
            comm = self.domain.parent_comm
        size = comm.Get_size()
        avg_time = comm.allreduce(tm.interval) / size
        msg = " Problem building took {} ({}s)"
        if size > 1:
            msg += f", averaged over {size} ranks. "
        msg = msg.format(datetime.timedelta(seconds=round(avg_time)), avg_time)
        vprint_banner(msg, spacing=True, at_border=2)

        if (args is not None) and args.stop_at_build:
            msg = " Problem has been built, exiting. "
            vprint_banner(msg, at_border=2)
            sys.exit(0)

    def get_preserved_input_fields(self):
        return set()

    def build_problem(
        self,
        args,
        allow_subbuffers,
        outputs_are_inputs=True,
        search_intertasks_ops=None,
    ):
        if (args is not None) and args.stop_at_initialization:
            return "initialization"

        vprint("\nInitializing problem... " + str(self.name))
        search_intertasks = search_intertasks_ops
        if search_intertasks is None:
            search_intertasks = self.search_intertasks_ops
        self.initialize(
            outputs_are_inputs=outputs_are_inputs,
            topgraph_method=None,
            is_root=True,
            search_intertasks_ops=search_intertasks,
        )

        if (args is not None) and args.stop_at_discretization:
            return "discretization"
        vprint("\nDiscretizing problem... " + str(self.name))
        for node in [_ for _ in self.nodes if isinstance(_, Problem)]:
            node.discretize()
        self.discretize()

        if (args is not None) and args.stop_at_work_properties:
            return "work properties retrieval"
        vprint("\nGetting work properties... " + str(self.name))
        work = self.get_work_properties()

        if (args is not None) and args.stop_at_work_allocation:
            return "work allocation"
        vprint("\nAllocating work... " + str(self.name))
        work.allocate(allow_subbuffers=allow_subbuffers)

        if (args is not None) and args.stop_at_setup:
            return "setup"
        vprint("\nSetting up problem..." + str(self.name))
        self.setup(work)

    def discretize(self):
        super().discretize()
        if self._do_check_unique_clenv:
            self.check_unique_clenv()

    def check_unique_clenv(self):
        cl_env, first_op = None, None
        for op in self.nodes:
            for topo in set(op.input_fields.values()).union(
                set(op.output_fields.values())
            ):
                if topo is not None and (topo.backend.kind == Backend.OPENCL):
                    if cl_env is None:
                        first_op = op
                        cl_env = topo.backend.cl_env
                    elif topo.backend.cl_env is not cl_env:
                        msg = ""
                        msg += "\nOpenCl environment mismatch between operator {} and operator {}."
                        msg = msg.format(first_op.name, op.name)
                        msg += f"\n{cl_env}"
                        msg += "\n and"
                        msg += f"\n{topo.backend.cl_env}"
                        msg += "\n If this is required, override check_unique_clenv()."
                        raise RuntimeError(msg)

    def initialize_field(self, field, mpi_params=None, **kwds):
        """Initialize a field on all its input and output topologies."""
        initialized = set()

        def __iterate_nodes(l):
            for e in l:
                if isinstance(e, Problem):
                    yield from __iterate_nodes(e.nodes)
                yield e

        # give priority to tensor field initialization
        for op_fields in (
            self.input_discrete_tensor_fields,
            self.output_discrete_tensor_fields,
        ) + tuple(
            _
            for op in __iterate_nodes(self.nodes)
            for _ in (
                op.input_discrete_tensor_fields,
                op.output_discrete_tensor_fields,
                op.input_discrete_fields,
                op.output_discrete_fields,
            )
        ):
            if field in op_fields:
                dfield = op_fields[field]
                if all((df in initialized) for df in dfield.discrete_fields()):
                    # all contained scalar fields were already initialized
                    continue
                elif mpi_params and not all(
                    [
                        mpi_params.task_id == df.topology.task_id
                        for df in dfield.discrete_fields()
                    ]
                ):
                    # topology task does not match the given mpi_params task
                    continue
                else:
                    components = ()
                    for component, scalar_dfield in dfield.nd_iter():
                        if scalar_dfield._dfield not in initialized:
                            components += (component,)
                    dfield.initialize(components=components, **kwds)
                    initialized.update(dfield.discrete_fields())

        if not initialized:
            msg = f"FATAL ERROR: Could not initialize field {field.name}."
            raise RuntimeError(msg)

    @debug
    @profile
    def solve(
        self,
        simu,
        dry_run=False,
        dbg=None,
        report_freq=10,
        plot_freq=10,
        checkpoint_handler=None,
        **kwds,
    ):
        if dry_run:
            vprint()
            vprint_banner("** Dry-run requested, skipping simulation. **")
            return

        simu.initialize()
        check_instance(checkpoint_handler, CheckpointHandler, allow_none=True)
        if checkpoint_handler is not None:
            checkpoint_handler.create_checkpoint_template(self, simu)
            checkpoint_handler.load_checkpoint(self, simu)

        vprint("\nSolving problem...")
        with Timer() as tm:
            while not simu.is_over:
                vprint()
                simu.print_state()
                self.apply(simulation=simu, dbg=dbg, **kwds)
                # determined before simu advance
                should_dump_checkpoint = (
                    checkpoint_handler is not None
                ) and checkpoint_handler.should_dump(simu)
                simu.advance(dbg=dbg, plot_freq=plot_freq)
                if should_dump_checkpoint:
                    checkpoint_handler.save_checkpoint(self, simu)
                if report_freq and (simu.current_iteration % report_freq) == 0:
                    self.profiler_report()

        comm = self.mpi_params.comm
        if (self.domain is not None) and self.domain.has_tasks:
            comm = self.domain.parent_comm
        size = comm.Get_size()
        avg_time = comm.allreduce(tm.interval) / size
        msg = " Simulation took {} ({}s)"
        if size > 1:
            msg += f", averaged over {size} ranks. "
        msg += "\n for {} iterations ({}s per iteration) "
        msg = msg.format(
            datetime.timedelta(seconds=round(avg_time)),
            avg_time,
            max(simu.current_iteration + 1, 1),
            avg_time / max(simu.current_iteration + 1, 1),
        )
        vprint_banner(msg, spacing=True, at_border=2)

        simu.finalize()
        if checkpoint_handler is not None:
            checkpoint_handler.finalize(self.mpi_params)
        self.final_report()

        if dbg is not None:
            dbg("final iteration", nostack=True)

    def final_report(self):
        self.profiler_report()
        if self.is_root or __DEBUG__ or self.__FORCE_REPORTS__:
            vprint(self.task_profiler_report())

    @debug
    def finalize(self):
        vprint("Finalizing problem...")
        super().finalize()
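
A typical call sequence chains the methods defined above: insert the operators, build the graph, then run the time loop and finalize. The sketch below is illustrative only and is not part of hysop.problem; op_a, op_b and simu stand for hypothetical operator and simulation objects, and only the Problem methods shown in this module are assumed.

# Minimal usage sketch (illustrative, not part of this module).
problem = Problem(name="example_problem")
problem.insert(op_a, op_b)   # op_a, op_b: hypothetical computational operators
problem.build()              # initialize, discretize, allocate work, set up
problem.solve(simu)          # simu: a hysop simulation object (assumed)
problem.finalize()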